import argparse
import os
import random
import numpy as np
import torch
import ast
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments
)
from torch import nn

def set_seed(seed: int):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch's CPU
    RNG, and additionally all CUDA devices when CUDA is available.

    Args:
        seed: Integer seed applied to each generator.
    """
    seeders = [random.seed, np.random.seed, torch.manual_seed]
    if torch.cuda.is_available():
        seeders.append(torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)


def _clutrr_query_str(query):
    """Render a CLUTRR ``query`` field as the question text used in prompts.

    * ``None`` -> a generic question without names.
    * A string starting with "(" that ``ast.literal_eval`` parses into a
      list/tuple of at least two items, e.g. ``"('Ann', 'Bob')"``, becomes
      "... between Ann and Bob? Answer:".
    * Anything else (including strings that fail to parse) is interpolated
      verbatim.
    """
    if query is None:
        return "What is the relationship? Answer:"
    if isinstance(query, str) and query.strip().startswith("("):
        try:
            parsed = ast.literal_eval(query)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            # These are the exceptions literal_eval documents for malformed input.
            parsed = None
        if isinstance(parsed, (list, tuple)) and len(parsed) >= 2:
            return f"What is the relationship between {parsed[0]} and {parsed[1]}? Answer:"
    return f"What is the relationship between {query}? Answer:"


def collate_fn(batch, tokenizer):
    """Collate raw CLUTRR examples into a padded causal-LM training batch.

    Each example is rendered as a prompt ("Story: <story>", a newline,
    "Query: <question>") followed by a space and ``target_text``. Only the
    target tokens are supervised: prompt tokens and padding get label -100.

    Args:
        batch: List of mapping-like examples. The optional keys ``"story"``,
            ``"query"`` and ``"target_text"`` are read.
        tokenizer: Hugging Face tokenizer with a pad token configured.

    Returns:
        The encoding returned by ``tokenizer.pad`` (``input_ids``,
        ``attention_mask``), plus a ``labels`` LongTensor of the same shape.
        Labels are padded on the same side as ``tokenizer.padding_side``, so
        they stay aligned with ``input_ids``.
    """
    inputs = []
    labels = []

    for item in batch:
        story = item.get("story", "")
        query_str = _clutrr_query_str(item.get("query"))
        target_text = item.get("target_text", "")

        prompt = f"Story: {story}\nQuery: {query_str}"
        combined = prompt + " " + target_text

        input_ids = tokenizer(combined, add_special_tokens=True, return_tensors=None)["input_ids"]
        # Assumes the prompt's tokens are a prefix of the combined tokens. This
        # presumably holds because the target is joined with a leading space.
        prompt_len = len(tokenizer(prompt, add_special_tokens=True, return_tensors=None)["input_ids"])

        inputs.append(input_ids)
        labels.append([-100] * prompt_len + input_ids[prompt_len:])

    batch_enc = tokenizer.pad(
        {"input_ids": inputs},
        padding=True,
        return_attention_mask=True,
        return_tensors="pt",
    )

    # Pad labels on the same side the tokenizer padded input_ids. Otherwise
    # left-padded batches would have labels shifted relative to their tokens.
    max_len = batch_enc["input_ids"].size(1)
    pad_left = getattr(tokenizer, "padding_side", "right") == "left"
    labels_padded = []
    for seq_labels in labels:
        filler = [-100] * (max_len - len(seq_labels))
        labels_padded.append(filler + seq_labels if pad_left else seq_labels + filler)
    batch_enc["labels"] = torch.tensor(labels_padded, dtype=torch.long)

    return batch_enc


class GPT2WithResidualBeta(nn.Module):
    """Frozen causal LM with one learnable scale ("beta") per transformer block.

    A forward hook on ``base_model.transformer.h[i].mlp`` multiplies that MLP's
    output by ``betas[i]``. In Hugging Face's GPT2Block the MLP output is then
    added to the residual stream. So this scales the MLP branch only; the skip
    path and the attention sub-layer are not modified.

    All base-model parameters are frozen, leaving ``betas`` (initialised to 1,
    i.e. identity) as the only trainable parameters.

    Note: the hooks are registered on the caller's ``base_model`` itself. While
    they are installed, calling ``base_model`` directly also applies the betas.
    Call ``remove_hooks()`` to restore the unmodified model.
    """

    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model

        # Freeze all parameters of the base model; only the betas are trained.
        for param in self.base_model.parameters():
            param.requires_grad = False

        # One beta per transformer block, starting at 1.0 (no change).
        num_layers = len(self.base_model.transformer.h)
        self.betas = nn.Parameter(torch.ones(num_layers))

        # Handles returned by register_forward_hook, kept so hooks can be removed.
        self.hooks = []
        self._register_hooks()

    def _register_hooks(self):
        """Install one forward hook per block that scales its MLP output by beta."""
        def get_hook_fn(layer_idx):
            # Factory binds layer_idx per hook (avoids late-binding closures).
            def hook_fn(module, input_tensors, output_tensors):
                # Returning a value from a forward hook replaces the module output.
                return output_tensors * self.betas[layer_idx]
            return hook_fn

        for i, block in enumerate(self.base_model.transformer.h):
            hook = block.mlp.register_forward_hook(get_hook_fn(i))
            self.hooks.append(hook)

    def remove_hooks(self):
        """Detach every beta hook from the base model (idempotent)."""
        for hook in self.hooks:
            hook.remove()
        self.hooks.clear()

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Delegate to the base model; the betas take effect through the MLP hooks."""
        return self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            **kwargs
        )

    def print_beta_values(self):
        """Print the current beta value of every layer, one per line."""
        beta_values = self.betas.detach().cpu().numpy()
        print("Beta values:")
        for i, beta in enumerate(beta_values):
            print(f"Layer {i}: {beta:.4f}")
        print("-" * 30)

def _eval_query_str(query):
    """Render a CLUTRR ``query`` field exactly as the training collator does.

    * ``None`` -> a generic question without names.
    * A string starting with "(" that parses into a list/tuple of at least two
      items -> "... between <first> and <second>? Answer:".
    * Anything else is interpolated verbatim. Parse failures are reported on
      stdout before falling back.
    """
    if query is None:
        return "What is the relationship? Answer:"
    if isinstance(query, str) and query.strip().startswith("("):
        try:
            parsed_query = ast.literal_eval(query)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as e:
            print(f"Error parsing query '{query}': {e}")
            parsed_query = None
        if isinstance(parsed_query, (list, tuple)) and len(parsed_query) >= 2:
            return f"What is the relationship between {parsed_query[0]} and {parsed_query[1]}? Answer:"
    # Same fallback as training, so eval prompts never diverge from train prompts.
    return f"What is the relationship between {query}? Answer:"


def process_test_data(test_dataset):
    """Build evaluation prompts for every example in *test_dataset*.

    The prompt layout matches the one used for training: "Story: <story>", a
    newline, then "Query: <question>".

    Args:
        test_dataset: Iterable of mapping-like examples. The optional keys
            ``"story"``, ``"query"`` and ``"target_text"`` are read.

    Returns:
        A list with one dict per example, holding ``"prompt"``, ``"combined"``
        (prompt + " " + target) and ``"target_text"``.
    """
    processed_examples = []

    for example in test_dataset:
        story = example.get("story", "")
        target_text = example.get("target_text", "")
        query_str = _eval_query_str(example.get("query", None))

        prompt = f"Story: {story}\nQuery: {query_str}"
        processed_examples.append({
            "prompt": prompt,
            "combined": prompt + " " + target_text,
            "target_text": target_text,
        })

    return processed_examples

def _encode_eval_batch(batch, tokenizer):
    """Tokenize processed examples into a padded batch with prompt-masked labels.

    Args:
        batch: List of dicts from ``process_test_data``; reads ``"prompt"`` and
            ``"combined"``.
        tokenizer: Hugging Face tokenizer with a pad token configured.

    Returns:
        The encoding from ``tokenizer.pad`` plus a ``"labels"`` LongTensor in
        which prompt tokens and padding are -100, so only target tokens score.
    """
    inputs = []
    labels = []
    for item in batch:
        input_ids = tokenizer(item["combined"], add_special_tokens=True, return_tensors=None)["input_ids"]
        # Assumes the prompt tokens are a prefix of the combined tokens
        # (presumably true because the target is joined with a leading space).
        prompt_len = len(tokenizer(item["prompt"], add_special_tokens=True, return_tensors=None)["input_ids"])
        inputs.append(input_ids)
        labels.append([-100] * prompt_len + input_ids[prompt_len:])

    batch_enc = tokenizer.pad(
        {"input_ids": inputs},
        padding=True,
        return_attention_mask=True,
        return_tensors="pt",
    )

    # Pad labels on the same side as input_ids so they stay aligned.
    max_len = batch_enc["input_ids"].size(1)
    pad_left = getattr(tokenizer, "padding_side", "right") == "left"
    labels_padded = []
    for seq_labels in labels:
        filler = [-100] * (max_len - len(seq_labels))
        labels_padded.append(filler + seq_labels if pad_left else seq_labels + filler)
    batch_enc["labels"] = torch.tensor(labels_padded, dtype=torch.long)
    return batch_enc


def evaluate_model(model, test_dataset, tokenizer, device, batch_size=16):
    """Evaluate a causal LM on CLUTRR-style examples.

    Two metrics are computed:
      * average loss: each batch's ``outputs.loss`` weighted by the number of
        examples in that batch, divided by the total number of examples;
      * accuracy: fraction of scorable examples whose *first* target token
        equals the argmax prediction at the preceding position.

    The Python, NumPy and PyTorch RNGs are reseeded to 42 on entry, which
    resets their global state. The model is left in eval mode.

    Args:
        model: Callable accepting ``input_ids``, ``attention_mask`` and
            ``labels`` and returning an object with ``.loss`` and ``.logits``.
        test_dataset: Iterable of raw examples accepted by ``process_test_data``.
        tokenizer: Hugging Face tokenizer with a pad token configured.
        device: Device to move each batch to before the forward pass.
        batch_size: Number of examples per forward pass (default 16).

    Returns:
        ``(avg_loss, accuracy)``. ``avg_loss`` is 0 for an empty dataset.
        ``accuracy`` is 0 when no example has a scorable target token.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    torch.manual_seed(42)
    random.seed(42)
    np.random.seed(42)

    model.eval()
    total_loss = 0
    correct_predictions = 0
    total_samples = 0

    print(f"Test dataset length: {len(test_dataset)}")

    processed_examples = process_test_data(test_dataset)
    print(f"Processed {len(processed_examples)} test examples")

    if processed_examples:
        print("Example processed test item:")
        print(f"Prompt: {processed_examples[0]['prompt'][:100]}...")
        print(f"Target: {processed_examples[0]['target_text']}")

    with torch.no_grad():
        for batch_idx in range(0, len(processed_examples), batch_size):
            batch = processed_examples[batch_idx:batch_idx + batch_size]
            batch_enc = _encode_eval_batch(batch, tokenizer)

            input_ids = batch_enc["input_ids"].to(device)
            attention_mask = batch_enc["attention_mask"].to(device)
            labels = batch_enc["labels"].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )
            total_loss += outputs.loss.item() * len(batch)

            # First-target-token accuracy, vectorised over the batch.
            logits = outputs.logits  # (batch_size, seq_len, vocab_size)
            target_mask = labels != -100
            # argmax returns the first maximal index, i.e. the first target
            # position. Rows with no target at all yield 0.
            first_label_positions = target_mask.float().argmax(dim=1)
            rows = torch.arange(labels.size(0), device=labels.device)
            # The token at position p is predicted by logits at p - 1, so p == 0
            # (and rows without targets) cannot be scored.
            scorable = (first_label_positions > 0) & target_mask[rows, first_label_positions]
            prev_positions = (first_label_positions - 1).clamp(min=0)
            pred_tokens = logits[rows, prev_positions].argmax(dim=-1)
            true_tokens = labels[rows, first_label_positions]
            correct_predictions += int(((pred_tokens == true_tokens) & scorable).sum().item())
            total_samples += int(scorable.sum().item())

            if batch_idx == 0:
                print("First batch example:")
                pos0 = first_label_positions[0].item()
                # Guard: decoding label -100 or logits[-1] would be meaningless or crash.
                if pos0 > 0 and target_mask[0, pos0]:
                    example_prompt = tokenizer.decode(input_ids[0][:pos0])
                    example_true_token = tokenizer.decode(labels[0, pos0].unsqueeze(0))
                    example_pred_token = tokenizer.decode(logits[0, pos0 - 1].argmax(dim=-1).unsqueeze(0))
                    print(f"Prompt:\n{example_prompt}")
                    print(f"True next token: '{example_true_token}'")
                    print(f"Predicted next token: '{example_pred_token}'")
                else:
                    print("(first example has no target tokens to show)")
                print("-" * 30)

    avg_loss = total_loss / len(processed_examples) if processed_examples else 0
    accuracy = correct_predictions / total_samples if total_samples > 0 else 0

    print(f"Total correct: {correct_predictions}, Total samples: {total_samples}")
    print(f"Final Accuracy: {accuracy:.4f}")

    return avg_loss, accuracy

def main():
    """Evaluate a GPT-2 checkpoint on CLUTRR, train per-layer betas, re-evaluate.

    Workflow:
      1. Parse CLI arguments, seed all RNGs, and write the configuration.
      2. Load CLUTRR (``gen_train23_test2to10``), dropping examples whose
         ``task_name`` ends in "1.2" or "1.3".
      3. Evaluate the unmodified checkpoint on the test split (baseline).
      4. Wrap it in ``GPT2WithResidualBeta`` and train only the betas on
         ``--train_split`` (default "train", so the test split stays held out).
      5. Save the learned betas and re-evaluate on the test split.

    Files written to ``--output_dir``: config.txt, baseline_metrics.txt,
    betas.pt, betas.txt and final_metrics.txt.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--lr", type=float, default=5e-5)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--checkpoint_dir", type=str, required=True)
    parser.add_argument("--output_dir", type=str, default="path/to/your/folder")
    parser.add_argument(
        "--train_split",
        type=str,
        default="train",
        help="Dataset split used to fit the betas. Evaluation always uses the "
             "'test' split, so passing 'test' here trains on the evaluation data.",
    )
    args = parser.parse_args()

    set_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    os.makedirs(args.output_dir, exist_ok=True)

    def out_path(name):
        # Files go directly in output_dir, which makedirs has just created.
        # Distinct names keep the outputs from overwriting each other.
        return os.path.join(args.output_dir, name)

    def write_metrics(name, loss, accuracy):
        with open(out_path(name), "w", encoding="utf-8") as f:
            f.write(f"Test Loss: {loss:.4f}\n")
            f.write(f"Test Accuracy: {accuracy:.4f}\n")

    # Save training configuration
    with open(out_path("config.txt"), "w", encoding="utf-8") as f:
        for arg in vars(args):
            f.write(f"{arg}: {getattr(args, arg)}\n")

    # GPT-2 has no pad token; reuse EOS. Padded positions are masked through
    # attention_mask and carry label -100, so they do not contribute to the loss.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    dataset = load_dataset("CLUTRR/v1", "gen_train23_test2to10")

    def task_filter(example):
        """Keep examples whose task name does not end in "1.2" or "1.3"."""
        task_name = example.get("task_name", "")
        return not task_name.endswith(("1.2", "1.3"))

    # Fit betas on a split other than the evaluation split unless explicitly asked.
    if args.train_split == "test":
        print("WARNING: training on the test split; reported test metrics are not held-out.")
    train_ds = dataset[args.train_split].filter(task_filter)
    test_dataset = dataset["test"].filter(task_filter)

    print(f"Loading model from checkpoint: {args.checkpoint_dir}")
    base_model = AutoModelForCausalLM.from_pretrained(args.checkpoint_dir)
    base_model = base_model.to(device)

    # Baseline evaluation of the unmodified checkpoint.
    print("Evaluating model on test dataset...")
    test_loss, test_accuracy = evaluate_model(base_model, test_dataset, tokenizer, device)
    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test Accuracy: {test_accuracy:.4f}")
    write_metrics("baseline_metrics.txt", test_loss, test_accuracy)

    # Wrap the (now frozen) model with one learnable beta per layer.
    model = GPT2WithResidualBeta(base_model)
    model.to(device)

    print("Initial beta values:")
    model.print_beta_values()

    # NOTE(review): Trainer applies weight_decay to "betas" (not a bias or
    # LayerNorm weight), which nudges them toward 0 rather than toward 1.
    training_args = TrainingArguments(
        seed=args.seed,
        data_seed=args.seed,
        output_dir=args.output_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        learning_rate=args.lr,
        weight_decay=0.01,
        remove_unused_columns=False,
        report_to=["none"],
        save_strategy="no"
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        data_collator=lambda b: collate_fn(b, tokenizer),
    )

    # Only the betas have requires_grad=True, so only they are updated.
    trainer.train()

    print("Final beta values after training:")
    model.print_beta_values()

    torch.save({
        "betas": model.betas.detach().cpu(),
    }, out_path("betas.pt"))

    print("Training complete. Beta parameters saved.")

    with open(out_path("betas.txt"), "w", encoding="utf-8") as f:
        for i, beta in enumerate(model.betas.detach().cpu().numpy()):
            f.write(f"Layer {i}: {beta:.6f}\n")

    # Evaluation with the learned betas applied.
    print("Evaluating model on test dataset...")
    test_loss, test_accuracy = evaluate_model(model, test_dataset, tokenizer, device)
    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test Accuracy: {test_accuracy:.4f}")
    write_metrics("final_metrics.txt", test_loss, test_accuracy)


# Script entry point: run the experiment only when executed directly, not on import.
if __name__ == "__main__":
    main()